import numpy as np
import pandas as pd


# ATR indicator: takes a pandas DataFrame of OHLC candles and returns a new one.

def ATR(data_df, period=14):
    """Compute the True Range (TR) and Average True Range (ATR) of candle data.

    The input is copied, sorted ascending by ``ts`` and re-indexed from 0.
    The returned frame has the columns ``open``/``high``/``low``/``close``
    renamed to ``o``/``h``/``l``/``c``. It also gains two new columns:

    * ``TR``  -- max(h - l, |h - previous c|, |l - previous c|). The first row
      has no previous close, so its TR is simply ``h - l``.
    * ``ATR`` -- simple rolling mean of ``TR`` over ``period`` rows. It is NaN
      until ``period`` rows are available. Note this is a plain SMA, not
      Wilder's smoothing.

    :param data_df: DataFrame with a ``ts`` column plus ``open``, ``high``,
        ``low`` and ``close`` columns. It is not modified.
    :param period: rolling-window length used to average TR into ATR.
    :return: a new DataFrame with the renamed columns plus ``TR`` and ``ATR``.
        An empty input yields an empty frame that still has both columns.
    """
    history_candles = data_df.copy()
    history_candles.sort_values(by="ts", ascending=True, inplace=True)
    history_candles.reset_index(drop=True, inplace=True)
    history_candles.rename(columns={'open': 'o', 'high': 'h', 'low': 'l', 'close': 'c'}, inplace=True)

    # Vectorized TR: a single pass over the columns instead of a per-row loop
    # that re-shifted the entire close column on every iteration (O(n^2)).
    prev_close = history_candles['c'].shift(1)
    tr_components = pd.concat(
        [
            history_candles['h'] - history_candles['l'],
            (history_candles['h'] - prev_close).abs(),
            (history_candles['l'] - prev_close).abs(),
        ],
        axis=1,
    )
    # skipna=True (the default) ignores the NaN previous-close terms on the
    # first row, so TR falls back to h - l there.
    history_candles['TR'] = tr_components.max(axis=1)
    history_candles['ATR'] = history_candles['TR'].rolling(period).mean()
    return history_candles


if __name__ == '__main__':
    # Script entry: load the sample candle CSV, compute ATR, and save the result.
    source_path = '../data_info/bit_coin.csv'
    output_path = "./ATR.csv"
    candles = pd.read_csv(source_path, index_col=0)
    ATR(candles).to_csv(output_path, index=False)
